{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OYlaRwNu7ojq"
      },
      "source": [
        "# **Homework 2 Phoneme Classification**\n",
        "\n",
        "* Slides: https://docs.google.com/presentation/d/1v6HkBWiJb8WNDcJ9_-2kwVstxUWml87b9CnA16Gdoio/edit?usp=sharing\n",
        "* Kaggle: https://www.kaggle.com/c/ml2022spring-hw2\n",
        "* Video: TBA\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mLQI0mNcmM-O",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "7d5b4d81-9438-4d50-8153-cd235c47ee21"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Wed Feb 23 14:42:18 2022       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   30C    P8    29W / 149W |      0MiB / 11441MiB |      0%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "|  No running processes found                                                 |\n",
            "+-----------------------------------------------------------------------------+\n"
          ]
        }
      ],
      "source": [
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KVUGfWTo7_Oj"
      },
      "source": [
        "## Download Data\n",
        "Download the data, then unzip it.\n",
        "\n",
        "After running the following block, you should have:\n",
        "- `libriphone/train_split.txt`\n",
        "- `libriphone/train_labels.txt`\n",
        "- `libriphone/test_split.txt`\n",
        "- `libriphone/feat/train/*.pt`: training features\n",
        "- `libriphone/feat/test/*.pt`: testing features\n",
        "\n",
        "> **Note: if the links are dead, you can download the data directly from [Kaggle](https://www.kaggle.com/c/ml2022spring-hw2/data) and upload it to the workspace, or you can use [the Kaggle API](https://www.kaggle.com/general/74235) to download the data directly into Colab.**\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bj5jYXsD9Ef3"
      },
      "source": [
        "### Download train/test metadata"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OzkiMEcC3Foq",
        "outputId": "cc90c16c-ee21-400e-ec08-dfcd422212a6"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "--2022-02-26 14:32:36--  https://github.com/xraychen/shiny-robot/releases/download/v1.0/libriphone.zip\n",
            "Resolving github.com (github.com)... 140.82.114.3\n",
            "Connecting to github.com (github.com)|140.82.114.3|:443... connected.\n",
            "HTTP request sent, awaiting response... 302 Found\n",
            "Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/463868124/343908dd-b2e4-4b8e-b7d6-7f0f040179ce?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20220226%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20220226T143237Z&X-Amz-Expires=300&X-Amz-Signature=3e65a2f51128a87a0d129ed899f0959cbc4456397645f3fc6b2f6a4751075250&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=463868124&response-content-disposition=attachment%3B%20filename%3Dlibriphone.zip&response-content-type=application%2Foctet-stream [following]\n",
            "--2022-02-26 14:32:37--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/463868124/343908dd-b2e4-4b8e-b7d6-7f0f040179ce?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20220226%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20220226T143237Z&X-Amz-Expires=300&X-Amz-Signature=3e65a2f51128a87a0d129ed899f0959cbc4456397645f3fc6b2f6a4751075250&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=463868124&response-content-disposition=attachment%3B%20filename%3Dlibriphone.zip&response-content-type=application%2Foctet-stream\n",
            "Resolving objects.githubusercontent.com (objects.githubusercontent.com)... 185.199.108.133, 185.199.111.133, 185.199.109.133, ...\n",
            "Connecting to objects.githubusercontent.com (objects.githubusercontent.com)|185.199.108.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 478737370 (457M) [application/octet-stream]\n",
            "Saving to: ‘libriphone.zip’\n",
            "\n",
            "libriphone.zip      100%[===================>] 456.56M  83.6MB/s    in 5.1s    \n",
            "\n",
            "2022-02-26 14:32:42 (90.2 MB/s) - ‘libriphone.zip’ saved [478737370/478737370]\n",
            "\n",
            "feat  test_split.txt  train_labels.txt\ttrain_split.txt\n"
          ]
        }
      ],
      "source": [
        "# Main link\n",
        "!wget -O libriphone.zip \"https://github.com/xraychen/shiny-robot/releases/download/v1.0/libriphone.zip\"\n",
        "\n",
        "# Backup Link 0\n",
        "# !pip install --upgrade gdown\n",
        "# !gdown --id '1o6Ag-G3qItSmYhTheX6DYiuyNzWyHyTc' --output libriphone.zip\n",
        "\n",
        "# Backup link 1\n",
        "# !pip install --upgrade gdown\n",
        "# !gdown --id '1R1uQYi4QpX0tBfUWt2mbZcncdBsJkxeW' --output libriphone.zip\n",
        "\n",
        "# Backup link 2\n",
        "# !wget -O libriphone.zip \"https://www.dropbox.com/s/wqww8c5dbrl2ka9/libriphone.zip?dl=1\"\n",
        "\n",
        "# Backup link 3\n",
        "# !wget -O libriphone.zip \"https://www.dropbox.com/s/p2ljbtb2bam13in/libriphone.zip?dl=1\"\n",
        "\n",
        "!unzip -q libriphone.zip\n",
        "!ls libriphone"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_L_4anls8Drv"
      },
      "source": [
        "### Preparing Data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "po4N3C-AWuWl"
      },
      "source": [
        "**Helper functions to pre-process the training data from raw MFCC features of each utterance.**\n",
        "\n",
        "A phoneme may span several frames and depends on past and future frames. \\\n",
        "Hence we concatenate neighboring frames for training to achieve higher accuracy. The **concat_feat** function concatenates the k past and k future frames (2k+1 = n frames in total), and we predict the center frame.\n",
        "\n",
        "Feel free to modify the data preprocessing functions, but **do not drop any frame** (if you modify the functions, remember to check that the number of frames is the same as mentioned in the slides)."
      ]
    },
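    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the windowing concrete, here is a minimal pure-Python sketch of the same idea on a toy input. `concat_neighbors` is a hypothetical helper written only for this illustration (it is not part of the homework code), and it assumes that edge frames are repeated at the utterance boundaries, which is what the `shift` helper in the next cell appears to do.\n",
        "\n",
        "```python\n",
        "def concat_neighbors(frames, n):\n",
        "    # Hypothetical helper for illustration only; not part of the homework code.\n",
        "    assert n % 2 == 1  # n = 2k + 1 must be odd\n",
        "    k = n // 2\n",
        "    last = len(frames) - 1\n",
        "    out = []\n",
        "    for t in range(len(frames)):\n",
        "        row = []\n",
        "        for offset in range(-k, k + 1):\n",
        "            # Clamp the index so the first/last frame is repeated at the edges.\n",
        "            src = min(max(t + offset, 0), last)\n",
        "            row.extend(frames[src])\n",
        "        out.append(row)\n",
        "    return out\n",
        "\n",
        "\n",
        "# 4 frames with 2 features each; n = 3 means one past and one future frame.\n",
        "frames = [[0, 1], [2, 3], [4, 5], [6, 7]]\n",
        "for row in concat_neighbors(frames, 3):\n",
        "    print(row)\n",
        "# [0, 1, 0, 1, 2, 3]\n",
        "# [0, 1, 2, 3, 4, 5]\n",
        "# [2, 3, 4, 5, 6, 7]\n",
        "# [4, 5, 6, 7, 6, 7]\n",
        "```\n",
        "\n",
        "Each row becomes n times wider while the number of rows (frames) stays the same, which is why the model input size is set to `39 * concat_nframes` later in this notebook."
      ]
    },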
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IJjLT8em-y9G"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import random\n",
        "import torch\n",
        "from tqdm import tqdm\n",
        "\n",
        "def load_feat(path):\n",
        "    feat = torch.load(path)\n",
        "    return feat\n",
        "\n",
        "def shift(x, n):\n",
        "    # shift the frames by n steps, repeating the first/last frame so the length is unchanged\n",
        "    if n < 0:\n",
        "        left = x[0].repeat(-n, 1)\n",
        "        right = x[:n]\n",
        "\n",
        "    elif n > 0:\n",
        "        right = x[-1].repeat(n, 1)\n",
        "        left = x[n:]\n",
        "    else:\n",
        "        return x\n",
        "\n",
        "    return torch.cat((left, right), dim=0)\n",
        "\n",
        "def concat_feat(x, concat_n):\n",
        "    assert concat_n % 2 == 1 # n must be odd\n",
        "    if concat_n < 2:\n",
        "        return x\n",
        "    seq_len, feature_dim = x.size(0), x.size(1)\n",
        "    x = x.repeat(1, concat_n)\n",
        "    x = x.view(seq_len, concat_n, feature_dim).permute(1, 0, 2) # concat_n, seq_len, feature_dim\n",
        "    mid = (concat_n // 2)\n",
        "    for r_idx in range(1, mid+1):\n",
        "        x[mid + r_idx, :] = shift(x[mid + r_idx], r_idx)\n",
        "        x[mid - r_idx, :] = shift(x[mid - r_idx], -r_idx)\n",
        "\n",
        "    return x.permute(1, 0, 2).view(seq_len, concat_n * feature_dim)\n",
        "\n",
        "def preprocess_data(split, feat_dir, phone_path, concat_nframes, train_ratio=0.8, train_val_seed=1337):\n",
        "    class_num = 41 # NOTE: pre-computed, should not need change\n",
        "    mode = 'train' if (split == 'train' or split == 'val') else 'test'\n",
        "\n",
        "    label_dict = {}\n",
        "    if mode != 'test':\n",
        "        # map each utterance id to its per-frame phoneme labels\n",
        "        with open(os.path.join(phone_path, f'{mode}_labels.txt')) as f:\n",
        "            phone_file = f.readlines()\n",
        "\n",
        "        for line in phone_file:\n",
        "            line = line.strip('\\n').split(' ')\n",
        "            label_dict[line[0]] = [int(p) for p in line[1:]]\n",
        "\n",
        "    if split == 'train' or split == 'val':\n",
        "        # split training and validation data\n",
        "        with open(os.path.join(phone_path, 'train_split.txt')) as f:\n",
        "            usage_list = f.readlines()\n",
        "        random.seed(train_val_seed)\n",
        "        random.shuffle(usage_list)\n",
        "        percent = int(len(usage_list) * train_ratio)\n",
        "        usage_list = usage_list[:percent] if split == 'train' else usage_list[percent:]\n",
        "    elif split == 'test':\n",
        "        with open(os.path.join(phone_path, 'test_split.txt')) as f:\n",
        "            usage_list = f.readlines()\n",
        "    else:\n",
        "        raise ValueError('Invalid \\'split\\' argument for preprocess_data: expected \\'train\\', \\'val\\', or \\'test\\'!')\n",
        "\n",
        "    usage_list = [line.strip('\\n') for line in usage_list]\n",
        "    print('[Dataset] - # phone classes: ' + str(class_num) + ', number of utterances for ' + split + ': ' + str(len(usage_list)))\n",
        "\n",
        "    max_len = 3000000 # size of the pre-allocated buffers; they are trimmed to the actual frame count below\n",
        "    X = torch.empty(max_len, 39 * concat_nframes)\n",
        "    if mode != 'test':\n",
        "        y = torch.empty(max_len, dtype=torch.long)\n",
        "\n",
        "    idx = 0\n",
        "    for i, fname in tqdm(enumerate(usage_list)):\n",
        "        feat = load_feat(os.path.join(feat_dir, mode, f'{fname}.pt'))\n",
        "        cur_len = len(feat)\n",
        "        feat = concat_feat(feat, concat_nframes)\n",
        "        if mode != 'test':\n",
        "            label = torch.LongTensor(label_dict[fname])\n",
        "\n",
        "        X[idx: idx + cur_len, :] = feat\n",
        "        if mode != 'test':\n",
        "            y[idx: idx + cur_len] = label\n",
        "\n",
        "        idx += cur_len\n",
        "\n",
        "    X = X[:idx, :]\n",
        "    if mode != 'test':\n",
        "        y = y[:idx]\n",
        "\n",
        "    print(f'[INFO] {split} set')\n",
        "    print(X.shape)\n",
        "    if mode != 'test':\n",
        "        print(y.shape)\n",
        "        return X, y\n",
        "    else:\n",
        "        return X\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "us5XW_x6udZQ"
      },
      "source": [
        "## Define Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fjf5EcmJtf4e"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "from torch.utils.data import Dataset\n",
        "from torch.utils.data import DataLoader\n",
        "\n",
        "class LibriDataset(Dataset):\n",
        "    def __init__(self, X, y=None):\n",
        "        self.data = X\n",
        "        if y is not None:\n",
        "            self.label = torch.LongTensor(y)\n",
        "        else:\n",
        "            self.label = None\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        if self.label is not None:\n",
        "            return self.data[idx], self.label[idx]\n",
        "        else:\n",
        "            return self.data[idx]\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.data)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IRqKNvNZwe3V"
      },
      "source": [
        "## Define Model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bg-GRd7ywdrL"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "class BasicBlock(nn.Module):\n",
        "    def __init__(self, input_dim, output_dim):\n",
        "        super(BasicBlock, self).__init__()\n",
        "\n",
        "        self.block = nn.Sequential(\n",
        "            nn.Linear(input_dim, output_dim),\n",
        "            nn.ReLU(),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.block(x)\n",
        "        return x\n",
        "\n",
        "\n",
        "class Classifier(nn.Module):\n",
        "    def __init__(self, input_dim, output_dim=41, hidden_layers=1, hidden_dim=256):\n",
        "        super(Classifier, self).__init__()\n",
        "\n",
        "        self.fc = nn.Sequential(\n",
        "            BasicBlock(input_dim, hidden_dim),\n",
        "            *[BasicBlock(hidden_dim, hidden_dim) for _ in range(hidden_layers)],\n",
        "            nn.Linear(hidden_dim, output_dim)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.fc(x)\n",
        "        return x"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Hyper-parameters"
      ],
      "metadata": {
        "id": "TlIq8JeqvvHC"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# data parameters\n",
        "concat_nframes = 1              # the number of frames to concatenate; n must be odd (total 2k+1 = n frames)\n",
        "train_ratio = 0.8               # the ratio of data used for training, the rest will be used for validation\n",
        "\n",
        "# training parameters\n",
        "seed = 0                        # random seed\n",
        "batch_size = 512                # batch size\n",
        "num_epoch = 5                   # the number of training epochs\n",
        "learning_rate = 0.0001          # learning rate\n",
        "model_path = './model.ckpt'     # the path where the checkpoint will be saved\n",
        "\n",
        "# model parameters\n",
        "input_dim = 39 * concat_nframes # the input dim of the model, you should not change the value\n",
        "hidden_layers = 1               # the number of hidden-to-hidden blocks added after the first (input) block\n",
        "hidden_dim = 256                # the hidden dim"
      ],
      "metadata": {
        "id": "iIHn79Iav1ri"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Prepare dataset and model"
      ],
      "metadata": {
        "id": "IIUFRgG5yoDn"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import gc\n",
        "\n",
        "# preprocess data\n",
        "train_X, train_y = preprocess_data(split='train', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)\n",
        "val_X, val_y = preprocess_data(split='val', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes, train_ratio=train_ratio)\n",
        "\n",
        "# get dataset\n",
        "train_set = LibriDataset(train_X, train_y)\n",
        "val_set = LibriDataset(val_X, val_y)\n",
        "\n",
        "# remove raw feature to save memory\n",
        "del train_X, train_y, val_X, val_y\n",
        "gc.collect()\n",
        "\n",
        "# get dataloader\n",
        "train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)\n",
        "val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)"
      ],
      "metadata": {
        "id": "c1zI3v5jyrDn",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "3ea2823a-83f3-42d9-ef05-2f2c002f9538"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[Dataset] - # phone classes: 41, number of utterances for train: 3428\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "3428it [00:01, 2097.23it/s]\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[INFO] train set\n",
            "torch.Size([2116368, 39])\n",
            "torch.Size([2116368])\n",
            "[Dataset] - # phone classes: 41, number of utterances for val: 858\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "858it [00:00, 2087.91it/s]"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[INFO] val set\n",
            "torch.Size([527790, 39])\n",
            "torch.Size([527790])\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CfRUEgC0GxUV",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "f9804711-72b1-4717-896b-821a300cfe87"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "DEVICE: cuda:0\n"
          ]
        }
      ],
      "source": [
        "device = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n",
        "print(f'DEVICE: {device}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "88xPiUnm0tAd"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# fix seed\n",
        "def same_seeds(seed):\n",
        "    torch.manual_seed(seed)\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.manual_seed(seed)\n",
        "        torch.cuda.manual_seed_all(seed)  \n",
        "    np.random.seed(seed)  \n",
        "    torch.backends.cudnn.benchmark = False\n",
        "    torch.backends.cudnn.deterministic = True"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QTp3ZXg1yO9Y"
      },
      "outputs": [],
      "source": [
        "# fix random seed\n",
        "same_seeds(seed)\n",
        "\n",
        "# create model, define a loss function, and optimizer\n",
        "model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)\n",
        "criterion = nn.CrossEntropyLoss() \n",
        "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Training"
      ],
      "metadata": {
        "id": "pwWH1KIqzxEr"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CdMWsBs7zzNs",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "17922ad2-a319-4253-8783-3e4939d0a7cf"
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 4134/4134 [00:24<00:00, 168.54it/s]\n",
            "100%|██████████| 1031/1031 [00:03<00:00, 276.80it/s]\n"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[001/005] Train Acc: 0.421913 Loss: 2.086623 | Val Acc: 0.440590 loss: 1.971902\n",
            "saving model with acc 0.441\n"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 4134/4134 [00:24<00:00, 166.62it/s]\n",
            "100%|██████████| 1031/1031 [00:03<00:00, 267.27it/s]\n"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[002/005] Train Acc: 0.449009 Loss: 1.934465 | Val Acc: 0.449662 loss: 1.926844\n",
            "saving model with acc 0.450\n"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 4134/4134 [00:25<00:00, 163.30it/s]\n",
            "100%|██████████| 1031/1031 [00:03<00:00, 272.71it/s]\n"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[003/005] Train Acc: 0.455054 Loss: 1.904332 | Val Acc: 0.453874 loss: 1.908874\n",
            "saving model with acc 0.454\n"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 4134/4134 [00:22<00:00, 179.96it/s]\n",
            "100%|██████████| 1031/1031 [00:03<00:00, 279.00it/s]\n"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[004/005] Train Acc: 0.458398 Loss: 1.887975 | Val Acc: 0.456089 loss: 1.896490\n",
            "saving model with acc 0.456\n"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 4134/4134 [00:23<00:00, 178.26it/s]\n",
            "100%|██████████| 1031/1031 [00:03<00:00, 275.13it/s]"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[005/005] Train Acc: 0.460776 Loss: 1.876422 | Val Acc: 0.457841 loss: 1.889746\n",
            "saving model with acc 0.458\n"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\n"
          ]
        }
      ],
      "source": [
        "best_acc = 0.0\n",
        "for epoch in range(num_epoch):\n",
        "    train_acc = 0.0\n",
        "    train_loss = 0.0\n",
        "    val_acc = 0.0\n",
        "    val_loss = 0.0\n",
        "    \n",
        "    # training\n",
        "    model.train() # set the model to training mode\n",
        "    for i, batch in enumerate(tqdm(train_loader)):\n",
        "        features, labels = batch\n",
        "        features = features.to(device)\n",
        "        labels = labels.to(device)\n",
        "        \n",
        "        optimizer.zero_grad() \n",
        "        outputs = model(features) \n",
        "        \n",
        "        loss = criterion(outputs, labels)\n",
        "        loss.backward() \n",
        "        optimizer.step() \n",
        "        \n",
        "        _, train_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n",
        "        train_acc += (train_pred.detach() == labels.detach()).sum().item()\n",
        "        train_loss += loss.item()\n",
        "    \n",
        "    # validation\n",
        "    if len(val_set) > 0:\n",
        "        model.eval() # set the model to evaluation mode\n",
        "        with torch.no_grad():\n",
        "            for i, batch in enumerate(tqdm(val_loader)):\n",
        "                features, labels = batch\n",
        "                features = features.to(device)\n",
        "                labels = labels.to(device)\n",
        "                outputs = model(features)\n",
        "                \n",
        "                loss = criterion(outputs, labels) \n",
        "                \n",
        "                _, val_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n",
        "                val_acc += (val_pred.cpu() == labels.cpu()).sum().item()\n",
        "                val_loss += loss.item()\n",
        "\n",
        "            print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f} | Val Acc: {:3.6f} loss: {:3.6f}'.format(\n",
        "                epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader), val_acc/len(val_set), val_loss/len(val_loader)\n",
        "            ))\n",
        "\n",
        "            # if the model improves, save a checkpoint at this epoch\n",
        "            if val_acc > best_acc:\n",
        "                best_acc = val_acc\n",
        "                torch.save(model.state_dict(), model_path)\n",
        "                print('saving model with acc {:.3f}'.format(best_acc/len(val_set)))\n",
        "    else:\n",
        "        print('[{:03d}/{:03d}] Train Acc: {:3.6f} Loss: {:3.6f}'.format(\n",
        "            epoch + 1, num_epoch, train_acc/len(train_set), train_loss/len(train_loader)\n",
        "        ))\n",
        "\n",
        "# if not validating, save the last epoch\n",
        "if len(val_set) == 0:\n",
        "    torch.save(model.state_dict(), model_path)\n",
        "    print('saving model at last epoch')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ab33MxosWLmG",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "911e8c9b-fc0f-4591-b0f6-311a1231c5e2"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "50"
            ]
          },
          "execution_count": null,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "del train_loader, val_loader\n",
        "gc.collect()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1Hi7jTn3PX-m"
      },
      "source": [
        "## Testing\n",
        "Create a testing dataset and load the model from the saved checkpoint."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VOG1Ou0PGrhc",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "abaaa25b-a93c-49b0-d228-9eca1e2ab2e0"
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[Dataset] - # phone classes: 41, number of utterances for test: 1078\n"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "1078it [00:00, 2496.92it/s]"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[INFO] test set\n",
            "torch.Size([646268, 39])\n"
          ]
        },
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "\n"
          ]
        }
      ],
      "source": [
        "# load data\n",
        "test_X = preprocess_data(split='test', feat_dir='./libriphone/feat', phone_path='./libriphone', concat_nframes=concat_nframes)\n",
        "test_set = LibriDataset(test_X, None)\n",
        "test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ay0Fu8Ovkdad",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "e5b20aa7-4d8b-43a9-e068-f5c89706a360"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "<All keys matched successfully>"
            ]
          },
          "execution_count": null,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# load model\n",
        "model = Classifier(input_dim=input_dim, hidden_layers=hidden_layers, hidden_dim=hidden_dim).to(device)\n",
        "model.load_state_dict(torch.load(model_path, map_location=device))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zp-DV1p4r7Nz"
      },
      "source": [
        "Make predictions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "84HU5GGjPqR0",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "cebd6694-8f74-44ff-f922-96ca4385acb8"
      },
      "outputs": [
        {
          "metadata": {
            "tags": null
          },
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 1263/1263 [00:02<00:00, 439.04it/s]\n"
          ]
        }
      ],
      "source": [
        "pred = np.array([], dtype=np.int32)\n",
        "\n",
        "model.eval()\n",
        "with torch.no_grad():\n",
        "    for i, batch in enumerate(tqdm(test_loader)):\n",
        "        features = batch\n",
        "        features = features.to(device)\n",
        "\n",
        "        outputs = model(features)\n",
        "\n",
        "        _, test_pred = torch.max(outputs, 1) # get the index of the class with the highest probability\n",
        "        pred = np.concatenate((pred, test_pred.cpu().numpy()), axis=0)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wyZqy40Prz0v"
      },
      "source": [
        "Write the predictions to a CSV file.\n",
        "\n",
        "After this block finishes running, download the file `prediction.csv` from the files section on the left-hand side and submit it to Kaggle."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GuljYSPHcZir"
      },
      "outputs": [],
      "source": [
        "with open('prediction.csv', 'w') as f:\n",
        "    f.write('Id,Class\\n')\n",
        "    for i, y in enumerate(pred):\n",
        "        f.write('{},{}\\n'.format(i, y))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "ML2022Spring - HW2.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}